import os
import random
import pandas as pd
import jieba
import numpy as np
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.layers import Dense, Flatten, Dropout
from keras.layers import Conv1D, MaxPooling1D, Embedding
from keras.models import Sequential
import matplotlib.pyplot as plt

# Predefined hyperparameters
MAX_SEQUENCE_LENGTH = 100    # maximum sequence length
EMBEDDING_DIM = 100    # embedding dimension
VALIDATION_SPLIT = 0.05    # fraction of data used for validation
TEST_SPLIT = 0.1    # fraction of data used for testing



def readFile(path):    # read one record; by default one file holds one sample
    with open(path, 'r', errors='ignore') as file:
        return file.read()

def readList(path):    # read a CSV and return its title column as a sample array
    data = pd.read_csv(path, encoding='utf-8')["标题"]  # "标题" = title column
    data.dropna(inplace=True)
    sample_list = np.array(data)
    return sample_list

def getStopWord(inputFile):  # load the stop-word list, one word per line
    stopWordList = readFile(inputFile).splitlines()
    return stopWordList

def saveFile(path, result):
    with open(path, 'w', errors='ignore') as file:
        file.write(result)

def cate2Num(cate):    # map a category directory name to a numeric label
    mapping = {
        '国家级': 0,    # national level
        '省区级': 1,    # provincial level
        '地区级': 2,    # regional level
    }
    if cate not in mapping:
        # returning a string here would crash to_categorical later, so fail fast
        raise ValueError('unknown category: %s' % cate)
    return mapping[cate]
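# Usage: cate2Num('国家级') == 0, cate2Num('省区级') == 1, cate2Num('地区级') == 2.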

def segText(inputPath, stopWordPath):
    fatherLists = os.listdir(inputPath)  # top-level directory, one sub-folder per category
    total = []
    for eachDir in fatherLists:  # iterate over the category folders
        eachPath = os.path.join(inputPath, eachDir)  # category folder path
        childLists = os.listdir(eachPath)  # files inside this category folder
        for eachFile in childLists:  # iterate over the CSV files of this category
            sentences = []
            eachPathFile = os.path.join(eachPath, eachFile)
            content = readList(eachPathFile)
            if eachDir in ('国家级', '省区级'):  # downsample the two larger classes
                np.random.shuffle(content)
                content = content[:600]  # keep at most 600 samples per file
            result = content.tolist()  # convert the numpy array to a plain list
            cate = cate2Num(eachDir)
            preprocess_text(result, sentences, cate, stopWordPath)
            for sentence in sentences:
                total.append(sentence)  # merge (segmented text, label) pairs from all files
    # shuffle the samples so the train/val/test split is not ordered by class
    random.shuffle(total)
    all_texts = [single[0] for single in total]
    all_labels = [single[1] for single in total]
    sequence_padding(all_texts, all_labels)

def preprocess_text(content_lines, sentences, category, stopWordPath):
    stopwords = getStopWord(stopWordPath)  # load the stop-word list once, not per line
    for line in content_lines:
        try:
            segs = jieba.lcut(line)  # word segmentation
            segs = [v for v in segs if not str(v).isdigit()]  # drop pure digits
            segs = list(filter(lambda x: x.strip(), segs))  # drop whitespace-only tokens
            segs = list(filter(lambda x: len(x) > 1, segs))  # drop single-character tokens
            segs = list(filter(lambda x: x not in stopwords, segs))  # drop stop words
            sentences.append((" ".join(segs), category))  # attach the label
        except Exception:
            print(line)  # print the line that failed to segment
            continue
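# For reference, the cleaning pipeline above typically turns a headline such as
# '2021年全国青少年科技创新大赛' into tokens like ['全国', '青少年', '科技', '创新', '大赛']:
# the exact segmentation depends on jieba's dictionary, and '2021' (digits) and
# the single character '年' are removed by the filters above.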

def sequence_padding(text, label):
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(text)
    sequences = tokenizer.texts_to_sequences(text)
    word_index = tokenizer.word_index
    print('Found %s unique tokens.' % len(word_index))
    data = pad_sequences(sequences, maxlen=MAX_SEQUENCE_LENGTH)
    labels = to_categorical(np.asarray(label))
    print('Shape of data tensor:', data.shape)
    print('Shape of label tensor:', labels.shape)
    split_list = data_split(data, labels)
    handle_model(word_index, labels, split_list)
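# Roughly how Tokenizer/pad_sequences behave on a toy corpus (illustrative only;
# indices are assigned by descending word frequency):
#   tokenizer.fit_on_texts(['猫 喜欢 鱼', '狗 喜欢 猫'])
#   tokenizer.texts_to_sequences(['猫 喜欢 鱼'])  -> [[1, 2, 3]]
#   pad_sequences([[1, 2, 3]], maxlen=5)          -> [[0, 0, 1, 2, 3]]  (zero pre-padding)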

def data_split(data, label):
    p1 = int(len(data) * (1 - VALIDATION_SPLIT - TEST_SPLIT))  # end of the training slice
    p2 = int(len(data) * (1 - TEST_SPLIT))  # end of the validation slice
    print(p1, p2)
    x_train = data[:p1]
    y_train = label[:p1]
    x_val = data[p1:p2]
    y_val = label[p1:p2]
    x_test = data[p2:]
    y_test = label[p2:]
    return x_train, y_train, x_val, y_val, x_test, y_test
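# Worked example: with 1000 samples, VALIDATION_SPLIT=0.05 and TEST_SPLIT=0.1,
# p1 = int(1000 * 0.85) = 850 and p2 = int(1000 * 0.9) = 900, giving
# 850 training, 50 validation and 100 test samples.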

def handle_model(word_index, labels, split_list):
    # 1-D convolutional text-classification model
    model = Sequential()
    model.add(Embedding(len(word_index) + 1, EMBEDDING_DIM, input_length=MAX_SEQUENCE_LENGTH))
    model.add(Dropout(0.2))
    model.add(Conv1D(100, 3, padding='valid', activation='relu', strides=1))
    model.add(MaxPooling1D(3))
    model.add(Flatten())
    model.add(Dense(100, activation='relu'))
    model.add(Dense(labels.shape[1], activation='softmax'))
    model.summary()
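    # Output-shape arithmetic for MAX_SEQUENCE_LENGTH = 100: Conv1D with kernel
    # size 3 and 'valid' padding yields 100 - 3 + 1 = 98 steps, MaxPooling1D(3)
    # reduces that to floor(98 / 3) = 32 steps, so Flatten feeds 32 * 100 = 3200
    # units into the first Dense layer.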

    # compile the model
    model.compile(loss='categorical_crossentropy',
                  optimizer='rmsprop',
                  metrics=['acc'])
    x_train, y_train, x_val, y_val, x_test, y_test = split_list
    history = model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, batch_size=128)
    os.makedirs('./model', exist_ok=True)  # model.save fails if the directory is missing
    model.save('./model/cnn.h5')
    # evaluate on the held-out test set
    print(model.evaluate(x_test, y_test))
    # plot training/validation accuracy
    plt.title('Accuracy')
    plt.plot(history.history['acc'], label='train')
    plt.plot(history.history['val_acc'], label='validation')
    plt.legend()
    plt.show()
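
# A minimal inference sketch (assumptions: the fitted Tokenizer is passed in by
# the caller -- the training flow above does not persist it, so in practice it
# would need to be saved, e.g. with pickle; predict_titles is a hypothetical
# helper, not called anywhere in this script; stop-word and digit filtering are
# omitted for brevity).
def predict_titles(titles, tokenizer, model_path='./model/cnn.h5'):
    from keras.models import load_model
    model = load_model(model_path)
    # apply the same basic preprocessing as training: segment, then space-join
    texts = [" ".join(jieba.lcut(t)) for t in titles]
    seqs = tokenizer.texts_to_sequences(texts)
    data = pad_sequences(seqs, maxlen=MAX_SEQUENCE_LENGTH)
    return np.argmax(model.predict(data), axis=1)  # numeric class labels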

# Load the corpus and run the pipeline
if __name__ == '__main__':
    datapath = "./total/"  # raw training-set path
    stopWord_path = "./stop/stopword.txt"  # stop-word list path
    test_path = "./test/"  # test-set path

    split_datapath = "./split/split_data/"  # path for the segmented training data
    test_split_path = "./split/test_split/"  # path for the segmented test data

    # read the training data and run segmentation, padding, training and evaluation
    segText(datapath, stopWord_path)

